[codex] Fix Mamba conv params under fine-grained FSDP gather #4467
ilml wants to merge 3 commits into NVIDIA:main
Conversation
CC'd from DMs: Hmm, so the MambaMixer has a Conv1D submodule, but the Conv1D submodule's weights are used in autograd functions during the MambaMixer.forward() pass?
MambaLayer (GraphableMegatronModule) was not recognized as an FSDP sharding unit, causing its parameters to remain in the root group and defeating ZeRO-3 param sharding for Mamba and hybrid models. Additionally, MambaMixer sets tensor_model_parallel and partition_dim directly on parameters (conv1d, A_log, dt_bias, D, norm.weight) rather than on the owning module. The TP annotation logic only checked module-level attributes, so these parameters were either unclassified or misclassified by the norm-name fallback (e.g. ExtendedRMSNorm treated as replicated when actually TP-sharded).

Changes:
- Register MambaLayer in default fsdp_unit_modules (mcore_fsdp_adapter) and sub_modules_to_wrap (torch_fully_sharded_data_parallel)
- Add param-level TP attribute fallback in _detect_parallelism_type, placed before the norm-name fallback so TP-sharded norm weights are correctly classified
- Pass param through from _annotate_tensor_parallelism
- Add tests for param-level TP detection, norm override, and a MambaMixer-like end-to-end annotation test

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
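A minimal sketch of the param-level fallback described above, assuming the tensor_model_parallel attribute names that Megatron stamps onto parameters; the real _detect_parallelism_type in Megatron-FSDP has more cases and different return labels:

```python
from typing import Optional

import torch.nn as nn


def _detect_parallelism_type(module: nn.Module, param: Optional[nn.Parameter] = None) -> str:
    # 1) Module-level annotation (existing behavior): TP-aware modules such as
    #    ColumnParallelLinear carry TP attributes on the module itself.
    if getattr(module, "tensor_model_parallel", False):
        return "tensor_parallel"
    # 2) Param-level fallback (the fix): MambaMixer stamps tensor_model_parallel /
    #    partition_dim directly on parameters (conv1d, A_log, dt_bias, D,
    #    norm.weight). Checked BEFORE the norm-name fallback so a TP-sharded
    #    ExtendedRMSNorm weight is not misclassified as replicated.
    if param is not None and getattr(param, "tensor_model_parallel", False):
        return "tensor_parallel"
    # 3) Norm-name fallback (existing behavior): plain norms are replicated.
    if "norm" in type(module).__name__.lower():
        return "replicated"
    return "data_parallel"
```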
Mamba's fused path reads conv1d weights directly instead of calling Conv1d.forward(), so fine-grained Megatron-FSDP never gathered those child parameters before the second forward. Register the conv module as an extra forward-gather source and resolve context-parallel Mamba params from the live mixer object.
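For intuition, a self-contained sketch of the failure mode using plain PyTorch hooks (no Megatron-FSDP involved): fine-grained gather runs in a per-module pre-forward hook, and a hook on a child module only fires when that child's forward is called, so a fused parent that reads the child's weight directly bypasses it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv1d(4, 4, 3, padding=2, groups=4)
fired = []
# Fine-grained FSDP gathers params in a module pre-forward hook like this one.
conv.register_forward_pre_hook(lambda module, args: fired.append("gathered"))

x = torch.randn(1, 4, 8)
conv(x)                       # hook fires: params would be gathered here
assert fired == ["gathered"]

# Fused-style direct read of the child's params, as in Mamba's fused path:
F.conv1d(x, conv.weight, conv.bias, padding=2, groups=4)
assert fired == ["gathered"]  # hook did NOT fire again: no gather happened
```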
Also, if we do implement this, I think Torch FSDP2 has something similar for "hitch-hiking" parameters into the all-gather. That's per-parameter / per-tensor though, so we're basically doing the same thing here with modules, I suppose.
cspades
left a comment
Looks pretty good; per-module parameter hitch-hiking seems reasonable to me since our FSDP units are based on modules as well.
@shjwudp @Autumn1998 @wujingyue in the rewrite we should expose this as an API, but this PR defines the feature. Is this general enough?
-    return self._slice_conv_param(self.conv1d_cp1.weight)
+    conv1d = self._mixer.conv1d if self._mixer is not None else self.conv1d_cp1
+    return self._slice_conv_param(conv1d.weight)
Have to ask, where/when do we have "stale" references such that we can no longer directly retrieve the weight?
wujingyue
left a comment
Thanks for the fix!
IIUC, [the linked code] is where def conv1d calls F.conv1d instead of a submodule.
Instead, could MambaContextParallel's constructor create the submodule from conv1d_cp1? For your reference, https://gitlab-master.nvidia.com/clara-discovery/boltz/-/blob/dev/src/boltz/distributed/model/layers/triangular_attention.py#L1490 is an internal example that adopts this practice for context parallelism.
cc @cspades
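For illustration, a hedged sketch of that alternative: build a real Conv1d submodule around conv1d_cp1's parameters so the conv call goes through Module.forward() and per-module FSDP hooks fire. conv1d_cp1 is from the PR; the wrapper class and wiring below are hypothetical.

```python
import torch.nn as nn


class MambaContextParallelSketch(nn.Module):
    """Hypothetical sketch of the suggestion above: own a real Conv1d submodule
    instead of calling F.conv1d on raw conv1d_cp1 weights."""

    def __init__(self, conv1d_cp1: nn.Conv1d):
        super().__init__()
        conv = nn.Conv1d(
            conv1d_cp1.in_channels,
            conv1d_cp1.out_channels,
            conv1d_cp1.kernel_size[0],
            padding=conv1d_cp1.padding,
            groups=conv1d_cp1.groups,
            bias=conv1d_cp1.bias is not None,
        )
        # Share the original Parameter objects (no copy), so the optimizer and
        # FSDP keep a single source of truth for these weights.
        conv.weight = conv1d_cp1.weight
        if conv1d_cp1.bias is not None:
            conv.bias = conv1d_cp1.bias
        self.conv1d = conv

    def forward(self, x):
        # Going through the submodule means FSDP's pre-forward gather hook on
        # self.conv1d runs, unlike the direct F.conv1d(x, weight, ...) path.
        return self.conv1d(x)
```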
```python
# Opt-in hitch-hiking: a module can declare extra child modules whose
# parameters must be gathered before this module's forward, because the
# forward reads those params directly without calling the child module.
extra_forward_param_modules = getattr(module, "_fsdp_extra_forward_param_modules", ())
if isinstance(extra_forward_param_modules, nn.Module):
    extra_forward_param_modules = (extra_forward_param_modules,)
if extra_forward_param_modules:
    # Dedup by identity so params already owned by this module are not
    # appended to the gather list twice.
    seen_param_ids = {id(param) for param in param_list}
    for extra_module in extra_forward_param_modules:
        for extra_param in extra_module.parameters():
            if id(extra_param) not in seen_param_ids:
                param_list.append(extra_param)
                seen_param_ids.add(id(extra_param))
```
There is also a post-forward / post-backward hook that calls this function: release_module_parameters. It needs to be called on the modules in extra_forward_param_modules so we can re-shard them.
Early note: what about the pre-backward param unshard?
We need to ensure that any rogue weights are re-sharded, and un-sharded during the backward pass. Did you check that the Conv-1D weights are re-sharded?
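To make that concern concrete, a hedged sketch of the release-side symmetry. release_module_parameters is the function named in the comment above; the wrapper and its hook wiring are assumptions, not the PR's actual code.

```python
import torch.nn as nn


def release_with_extra_modules(module: nn.Module, release_module_parameters) -> None:
    # Mirror the gather-side hitch-hiking: after forward/backward, re-shard
    # not only this module's own params but also those of any extra modules
    # it declared via _fsdp_extra_forward_param_modules.
    release_module_parameters(module)
    extras = getattr(module, "_fsdp_extra_forward_param_modules", ())
    if isinstance(extras, nn.Module):
        extras = (extras,)
    for extra_module in extras:
        release_module_parameters(extra_module)
```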
cspades
left a comment
Still WIP, will approve when all features are implemented!
wujingyue
left a comment
Request for clarification: did you run into this issue with enable_fine_grained_param_gather_hook? When MFSDP is applied to MambaLayer, it ought to find all parameters under that layer including the ones in conv1d. It wasn't clear to me how MFSDP missed the parameter in the first place. cc @ilml
It's automatically turned on for MXFP8. (I still forget exactly why.)
yeah, this is a very subtle bug:
Similar to what I said earlier, I think the right fix is to make [...]. The reasons are: [...]
Wdyt?

Summary
Fix fine-grained Megatron-FSDP parameter gathering for Mamba's fused conv path.
This branch is now stacked with the MambaLayer FSDP support from #4329 first, then the direct conv-param gather fix. #4329 makes MambaLayer an FSDP unit and fixes TP annotation for SSM parameters; this PR handles the remaining case where MambaMixer.forward() reads a child module's parameters without invoking that child module's forward hook.

Mamba's memory-efficient fused path reads conv1d.weight and conv1d.bias directly and passes them into mamba_split_conv1d_scan_combined, instead of calling Conv1d.forward(). With fine-grained Megatron-FSDP gather enabled, that means the child conv1d module's pre-forward gather hook never runs. After the first forward releases parameter storage, the second forward can pass a null-base sharded view into causal-conv, producing an illegal memory access.

This PR adds an opt-in _fsdp_extra_forward_param_modules hook for modules that directly read child-module params, and uses it from MambaMixer for self.conv1d. It also makes Mamba context-parallel parameter access resolve through the live mixer object so FSDP-updated parameters are not bypassed by stale cached references.
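A hedged sketch of how the opt-in is meant to be used, with a toy mixer standing in for MambaMixer. The _fsdp_extra_forward_param_modules attribute name is from this PR; everything else below is illustrative.

```python
import torch.nn as nn
import torch.nn.functional as F


class ToyFusedMixer(nn.Module):
    """Toy stand-in for MambaMixer's fused path: consumes conv1d's params
    directly, so conv1d.forward() (and its FSDP gather hook) never runs."""

    def __init__(self, channels: int, kernel_size: int = 4):
        super().__init__()
        self.conv1d = nn.Conv1d(
            channels, channels, kernel_size, groups=channels, padding=kernel_size - 1
        )
        # Opt in: ask fine-grained FSDP to also gather conv1d's params in this
        # module's pre-forward hook, since the fused kernel reads them directly.
        self._fsdp_extra_forward_param_modules = (self.conv1d,)

    def forward(self, x):
        # Fused-style direct read of child params (stand-in for
        # mamba_split_conv1d_scan_combined).
        y = F.conv1d(
            x,
            self.conv1d.weight,
            self.conv1d.bias,
            padding=self.conv1d.padding[0],
            groups=self.conv1d.groups,
        )
        return y[..., : x.shape[-1]]  # causal truncation
```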
Validation

- python3 -m py_compile megatron/core/distributed/fsdp/mcore_fsdp_adapter.py megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py megatron/core/distributed/torch_fully_sharded_data_parallel.py megatron/core/ssm/mamba_context_parallel.py megatron/core/ssm/mamba_mixer.py tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py
- python3 -m pytest tests/unit_tests/distributed/megatron_fsdp/test_mcore_tensor_parallelism_detect.py -q could not run on the login node because pytest is not installed there.
- /home/tolong/work/dsv3/interactive_nt.sh completed through iteration 50 and saved a checkpoint with no CUDA error, illegal memory, or FAILED in the log: /lustre/fsw/coreai_dlalgo_llm/tolong/results/nemo_megatron/megatron/nemotron6/hybrid/debug/interactive_nt_debug/logs/interactive_nt_full_20260424_150942.log